/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

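/*!
* \file mx_matmul_utils.h
* \brief Compile-time type traits and helper predicates (element types, bit sizes, scale positions) for matmul.
*/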

#ifndef IMPL_MATMUL_UTILS_MX_MATMUL_UTILS_H
#define IMPL_MATMUL_UTILS_MX_MATMUL_UTILS_H


#include "matmul_type_def.h"
#include "../feature_trait/matmul_feature_trait.h"
namespace AscendC {

// Boolean shorthand for AscendC::IsSameType<Lhs, Rhs>::value.
template <typename Lhs, typename Rhs>
constexpr bool IsSameTypeV = AscendC::IsSameType<Lhs, Rhs>::value;

template <typename T, typename... Others>
struct IsTypeOneOf {
    static constexpr bool value = false;
};

template <typename T, typename First, typename... Others>
struct IsTypeOneOf<T, First, Others...> {
    static constexpr bool value = IsSameTypeV<T, First> || IsTypeOneOf<T, Others...>::value;
};

// Boolean shorthand for IsTypeOneOf<T, Candidates...>::value.
template <typename T, typename... Candidates>
constexpr bool IsTypeOneOfV = IsTypeOneOf<T, Candidates...>::value;

// Maps a source element type to the type published as `Type`.
// Primary template: identity mapping (Type is SrcT itself).
template <typename SrcT>
struct GetMmDstType {
    using Type = SrcT;
};

// Maps an element type, together with an `isMxType` flag, to the type published as `Type`.
// Primary template: yields SrcT unchanged for either value of the flag (flag defaults to false).
template <typename SrcT, bool isMxType = false>
struct GetL0DataType {
    using Type = SrcT;
};

#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
// Explicit specializations of GetMmDstType: each of the source types below maps to `float`
// instead of the identity mapping provided by the primary template.
template <> struct GetMmDstType<fp8_e4m3fn_t> { using Type = float; };
template <> struct GetMmDstType<fp8_e5m2_t>   { using Type = float; };
template <> struct GetMmDstType<hifloat8_t>   { using Type = float; };
template <> struct GetMmDstType<fp4x2_e2m1_t> { using Type = float; };
template <> struct GetMmDstType<fp4x2_e1m2_t> { using Type = float; };

// Explicit specializations of GetL0DataType for the two fp8 types.
// With isMxType == true the type is replaced by its AscendC::mx_* counterpart;
// with isMxType == false the type is kept as-is.

// fp8_e5m2_t
template <> struct GetL0DataType<fp8_e5m2_t, true>  { using Type = AscendC::mx_fp8_e5m2_t; };
template <> struct GetL0DataType<fp8_e5m2_t, false> { using Type = fp8_e5m2_t; };

// fp8_e4m3fn_t
template <> struct GetL0DataType<fp8_e4m3fn_t, true>  { using Type = AscendC::mx_fp8_e4m3_t; };
template <> struct GetL0DataType<fp8_e4m3fn_t, false> { using Type = fp8_e4m3fn_t; };
#endif

// Selects one of the Impl::*_C0SIZE constants for SrcT. Checks are made in this order:
//   1. sizeof(SrcT) == sizeof(float)        -> Impl::B32_C0SIZE
//   2. SrcT is in the "B8" list for the arch -> Impl::B8_C0SIZE
//   3. SrcT is in the "B4" list for the arch -> Impl::B4_C0SIZE
//   4. anything else                         -> Impl::B16_C0SIZE
// On __NPU_ARCH__ == 3101 the B8/B4 lists also include the fp8 / hifloat8 / fp4x2 types.
template <typename SrcT>
__aicore__ inline constexpr static int32_t AuxGetC0Size()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
    constexpr bool isB8Type = IsTypeOneOfV<SrcT, int8_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t, fp8_e8m0_t>;
    constexpr bool isB4Type = IsTypeOneOfV<SrcT, int4b_t, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    constexpr bool isB8Type = IsSameType<SrcT, int8_t>::value;
    constexpr bool isB4Type = IsSameType<SrcT, int4b_t>::value;
#endif
    constexpr bool isFloatSized = (sizeof(SrcT) == sizeof(float));

    // The size check has priority over the type lists, as in the if/else-if chain it replaces.
    if (isFloatSized) {
        return Impl::B32_C0SIZE;
    }
    if (isB8Type) {
        return Impl::B8_C0SIZE;
    }
    if (isB4Type) {
        return Impl::B4_C0SIZE;
    }
    return Impl::B16_C0SIZE;
}

// Returns true exactly when SrcT is int32_t or float.
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB32()
{
    return IsTypeOneOfV<SrcT, int32_t, float>;
}

// Returns true when SrcT is int8_t; on __NPU_ARCH__ == 3101 it is also true for
// hifloat8_t, fp8_e4m3fn_t and fp8_e5m2_t. Returns false for every other type.
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB8()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
    return IsTypeOneOfV<SrcT, int8_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t>;
#else
    return IsSameTypeV<SrcT, int8_t>;
#endif
}

// Returns true when SrcT is int4b_t; on __NPU_ARCH__ == 3101 it is also true for
// fp4x2_e1m2_t and fp4x2_e2m1_t. Returns false for every other type.
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportB4()
{
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
    return IsTypeOneOfV<SrcT, int4b_t, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    return IsSameTypeV<SrcT, int4b_t>;
#endif
}

// Returns true when SrcT is fp4x2_e1m2_t or fp4x2_e2m1_t and the build defines
// __DAV_C310__ or __DAV_310R6__; returns false in every other case.
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportMxFp4()
{
#if defined(__DAV_C310__) || defined(__DAV_310R6__)
    return IsTypeOneOfV<SrcT, fp4x2_e1m2_t, fp4x2_e2m1_t>;
#else
    return false;
#endif
}

// Returns true when SrcT is fp8_e4m3fn_t or fp8_e5m2_t and the build defines
// __DAV_C310__ or __DAV_310R6__; returns false in every other case.
template <typename SrcT>
__aicore__ inline constexpr bool IsSupportMxFp8()
{
#if defined(__DAV_C310__) || defined(__DAV_310R6__)
    return IsTypeOneOfV<SrcT, fp8_e4m3fn_t, fp8_e5m2_t>;
#else
    return false;
#endif
}

// Returns true when T satisfies IsSupportB8<T>() or IsSupportB4<T>().
template <typename T>
__aicore__ inline constexpr static bool IsNeedC0Align()
{
    constexpr bool isB8 = IsSupportB8<T>();
    constexpr bool isB4 = IsSupportB4<T>();
    return isB8 || isB4;
}

#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
// NOTE(review): not referenced elsewhere in this header; the meaning of mode value 4 is defined by its users.
constexpr uint8_t INTRA_MODE{4};
// Returns PhyPosIsL1(INPUT_TYPE::scalePosition) when HasScalePosition<INPUT_TYPE>::value is true,
// and false otherwise. The member is only named in the branch that is compiled when it is reported present.
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsL1()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsL1(INPUT_TYPE::scalePosition);
    }
}

// Returns PhyPosIsUB(INPUT_TYPE::scalePosition) when HasScalePosition<INPUT_TYPE>::value is true,
// and false otherwise. The member is only named in the branch that is compiled when it is reported present.
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsUB()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsUB(INPUT_TYPE::scalePosition);
    }
}

// Returns PhyPosIsGM(INPUT_TYPE::scalePosition) when HasScalePosition<INPUT_TYPE>::value is true,
// and false otherwise. The member is only named in the branch that is compiled when it is reported present.
template <typename INPUT_TYPE>
__aicore__ constexpr bool PhyMxScalePosIsGM()
{
    if constexpr (!HasScalePosition<INPUT_TYPE>::value) {
        return false;
    } else {
        return PhyPosIsGM(INPUT_TYPE::scalePosition);
    }
}
#endif

// Returns the bit width associated with element type T. Checks are made in this order:
//   - std::is_arithmetic<T>                         -> sizeof(T) * ONE_BYTE_BIT_SIZE
//   - AscendC::int4b_t                              -> ONE_BYTE_BIT_SIZE / 2
//   - (__NPU_ARCH__ == 3101) fp8_e8m0_t, hifloat8_t,
//     fp8_e4m3fn_t, fp8_e5m2_t                      -> ONE_BYTE_BIT_SIZE
//   - (__NPU_ARCH__ == 3101) fp4x2_e2m1_t,
//     fp4x2_e1m2_t                                  -> ONE_BYTE_BIT_SIZE / 2
//   - (__NPU_ARCH__ == 5102) AscendC::int2b_t       -> ONE_BYTE_BIT_SIZE / 4
//   - any other type                                -> ONE_BYTE_BIT_SIZE * 2
template <typename T>
__aicore__ constexpr int32_t GetBitSize()
{
    if constexpr (std::is_arithmetic<T>::value) {
        // sizeof yields size_t; make the conversion to the declared return type explicit.
        return static_cast<int32_t>(sizeof(T) * ONE_BYTE_BIT_SIZE);
    }
    if constexpr (IsSameTypeV<T, AscendC::int4b_t>) {
        return ONE_BYTE_BIT_SIZE / 2;
    }
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
    if constexpr (IsTypeOneOfV<T, fp8_e8m0_t, hifloat8_t, fp8_e4m3fn_t, fp8_e5m2_t>) {
        return ONE_BYTE_BIT_SIZE;
    }
    if constexpr (IsTypeOneOfV<T, fp4x2_e2m1_t, fp4x2_e1m2_t>) {
        return ONE_BYTE_BIT_SIZE / 2;
    }
#endif
    // Guarded with defined() like the 3101 checks, and resolved with `if constexpr` like the other type tests.
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 5102
    if constexpr (IsSameTypeV<T, AscendC::int2b_t>) {
        return ONE_BYTE_BIT_SIZE / 4;
    }
#endif

    // Fallback for every type not matched above.
    return ONE_BYTE_BIT_SIZE * 2;
}

template <typename T>
constexpr bool IsScaleTransWithInlv = (HasScalePosition<T>::value && PhyPosIsGM(T::pos) &&
    (T::format == CubeFormat::ND) && PhyPosIsL1(T::scalePosition));

// Returns true only when MatmulFeatureTrait<MM_CFG>::IsMmadInstrSupportAntiQuant() is false AND the bit sizes
// (per GetBitSize) of B_TYPE::T and A_TYPE::T differ. Returns false in every other case.
template <typename A_TYPE, typename B_TYPE, const auto& MM_CFG>
__aicore__ inline constexpr bool IsL1BNeedTrans()
{
    if constexpr (Impl::Detail::MatmulFeatureTrait<MM_CFG>::IsMmadInstrSupportAntiQuant()) {
        return false;
    } else {
        // Bit sizes are only computed when the trait above reports no anti-quant support.
        return GetBitSize<typename B_TYPE::T>() != GetBitSize<typename A_TYPE::T>();
    }
}

// Returns a value-initialized object whose *type* is the result of the selection below
// (presumably consumed via decltype by callers -- confirm at the call sites):
//   - HasScalePosition<A_TYPE>::value                         -> B_TYPE
//   - __NPU_ARCH__ == 5102, DecompMode 1bitTo4bit/2bitTo4bit   -> MatmulType<TPosition::GM, CubeFormat::NZ, int4b_t>
//   - __NPU_ARCH__ == 5102, DecompMode 4bitTo8bit              -> MatmulType<TPosition::GM, CubeFormat::NZ, int8_t>
//   - other archs, IsL1BNeedTrans<A_TYPE, B_TYPE, MM_CFG>()    -> A_TYPE
//   - otherwise                                                -> B_TYPE
// The branches must stay a single if-constexpr/else chain: the return type is deduced (`auto`), so only one
// return statement may survive per instantiation.
template <typename A_TYPE, typename B_TYPE, const auto& MM_CFG>
__aicore__ inline constexpr auto GetTransBDataType()
{
    if constexpr (HasScalePosition<A_TYPE>::value) {
        return B_TYPE{};
    }
#if __NPU_ARCH__ == 5102
    else if constexpr (DecompMode(MM_CFG) == DecompressionMode::DECOMP_1bitTo4bit ||
                       DecompMode(MM_CFG) == DecompressionMode::DECOMP_2bitTo4bit) {
        return MatmulType<TPosition::GM, CubeFormat::NZ, int4b_t>{};
    } else if constexpr (DecompMode(MM_CFG) == DecompressionMode::DECOMP_4bitTo8bit) {
        return MatmulType<TPosition::GM, CubeFormat::NZ, int8_t>{};
    }
#else
    else if constexpr (IsL1BNeedTrans<A_TYPE, B_TYPE, MM_CFG>()) {
        return A_TYPE{};
    }
#endif
    else {
        return B_TYPE{};
    }
}
// Returns true when INPUT_TYPE::TAG equals InputTypeTag::scaleA or InputTypeTag::scaleB.
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool IsScaleTag()
{
    constexpr bool isScaleA = (INPUT_TYPE::TAG == InputTypeTag::scaleA);
    constexpr bool isScaleB = (INPUT_TYPE::TAG == InputTypeTag::scaleB);
    return isScaleA || isScaleB;
}

// Applies PhyPosIsGM to the position that matches the input's tag:
// INPUT_TYPE::scalePosition for scale-tagged inputs (see IsScaleTag), INPUT_TYPE::pos for all others.
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsGM()
{
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsGM(INPUT_TYPE::pos);
    } else {
        return PhyPosIsGM(INPUT_TYPE::scalePosition);
    }
}

// Applies PhyPosIsL1 to the position that matches the input's tag:
// INPUT_TYPE::scalePosition for scale-tagged inputs (see IsScaleTag), INPUT_TYPE::pos for all others.
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsL1()
{
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsL1(INPUT_TYPE::pos);
    } else {
        return PhyPosIsL1(INPUT_TYPE::scalePosition);
    }
}

// Applies PhyPosIsUB to the position that matches the input's tag:
// INPUT_TYPE::scalePosition for scale-tagged inputs (see IsScaleTag), INPUT_TYPE::pos for all others.
template <typename INPUT_TYPE>
__aicore__ inline constexpr bool InputPhyPosIsUB()
{
    if constexpr (!IsScaleTag<INPUT_TYPE>()) {
        return PhyPosIsUB(INPUT_TYPE::pos);
    } else {
        return PhyPosIsUB(INPUT_TYPE::scalePosition);
    }
}

// True when T is one of fp8_e4m3fn_t, fp8_e5m2_t, fp4x2_e2m1_t, fp4x2_e1m2_t and __NPU_ARCH__ == 3101;
// always false on every other architecture. Note that, despite the name, the list includes the fp4x2 types.
template <typename T>
constexpr bool SupportMXFP8 =
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 3101
    IsTypeOneOfV<T, fp8_e4m3fn_t, fp8_e5m2_t, fp4x2_e2m1_t, fp4x2_e1m2_t>;
#else
    false;
#endif
// True when all of the following hold:
//   - EnUnitFlag(MM_CFG)
//   - HasScalePosition<AType>::value
//   - AType::isTrans, or !BType::isTrans, or IsStaticPaddingEnable(MM_CFG)
//   - SupportMXFP8<typename AType::T>
template <typename AType, typename BType, const auto& MM_CFG>
constexpr bool IsMxDisableUnitFlag =
    (EnUnitFlag(MM_CFG) && HasScalePosition<AType>::value &&
     (AType::isTrans || !BType::isTrans || IsStaticPaddingEnable(MM_CFG)) &&
     SupportMXFP8<typename AType::T>);
} // namespace AscendC
#endif // IMPL_MATMUL_UTILS_MX_MATMUL_UTILS_H
